{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文本数据处理的方法\n",
    "1. 利用tf.keras.preprocessing中的Tokenizer词典构建工具和tf.keras.utils.Sequence构建文本数据生成器管道，（较复杂https://zhuanlan.zhihu.com/p/67697840）\n",
    "2. 使用tf.data.Dataset搭配keras.layers.expermental.preprocessing.TextVectorization预处理层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "from matplotlib import pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import models,layers,preprocessing,optimizers,losses,metrics\n",
    "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
    "import re,string\n",
    "\n",
    "train_data_path = \"./data/imdb/train.csv\"\n",
    "test_data_path =  \"./data/imdb/test.csv\"\n",
    "\n",
    "MAX_WORDS = 10000  # 仅考虑最高频的10000个词\n",
    "MAX_LEN = 200  # 每个样本保留200个词的长度\n",
    "BATCH_SIZE = 20 \n",
    "\n",
    "\n",
    "#构建管道\n",
    "def split_line(line):\n",
    "    arr = tf.strings.split(line,\"\\t\")\n",
    "    label = tf.expand_dims(tf.cast(tf.strings.to_number(arr[0]),tf.int32),axis = 0)\n",
    "    text = tf.expand_dims(arr[1],axis = 0)\n",
    "    return (text,label)\n",
    "\n",
    "ds_train_raw =  tf.data.TextLineDataset(filenames = [train_data_path]) \\\n",
    "   .map(split_line,num_parallel_calls = tf.data.experimental.AUTOTUNE) \\\n",
    "   .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \\\n",
    "   .prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "ds_test_raw = tf.data.TextLineDataset(filenames = [test_data_path]) \\\n",
    "   .map(split_line,num_parallel_calls = tf.data.experimental.AUTOTUNE) \\\n",
    "   .batch(BATCH_SIZE) \\\n",
    "   .prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "\n",
    "#构建词典\n",
    "def clean_text(text):\n",
    "    lowercase = tf.strings.lower(text)\n",
    "    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
    "    cleaned_punctuation = tf.strings.regex_replace(stripped_html,\n",
    "         '[%s]' % re.escape(string.punctuation),'')\n",
    "    return cleaned_punctuation\n",
    "\n",
    "vectorize_layer = TextVectorization(\n",
    "    standardize=clean_text,\n",
    "    split = 'whitespace',\n",
    "    max_tokens=MAX_WORDS-1, #有一个留给占位符\n",
    "    output_mode='int',\n",
    "    output_sequence_length=MAX_LEN)\n",
    "\n",
    "ds_text = ds_train_raw.map(lambda text,label: text)\n",
    "vectorize_layer.adapt(ds_text)\n",
    "print(vectorize_layer.get_vocabulary()[0:100])\n",
    "\n",
    "\n",
    "#单词编码\n",
    "ds_train = ds_train_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
    "ds_test = ds_test_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 演示自定义模型范例，实际上应该优先使用Sequential或者函数式API\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "class CnnModel(models.Model):\n",
    "    def __init__(self):\n",
    "        super(CnnModel, self).__init__()\n",
    "        \n",
    "    def build(self,input_shape):\n",
    "        self.embedding = layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN)\n",
    "        self.conv_1 = layers.Conv1D(16, kernel_size= 5,name = \"conv_1\",activation = \"relu\")\n",
    "        self.pool_1 = layers.MaxPool1D(name = \"pool_1\")\n",
    "        self.conv_2 = layers.Conv1D(128, kernel_size=2,name = \"conv_2\",activation = \"relu\")\n",
    "        self.pool_2 = layers.MaxPool1D(name = \"pool_2\")\n",
    "        self.flatten = layers.Flatten()\n",
    "        self.dense = layers.Dense(1,activation = \"sigmoid\")\n",
    "        super(CnnModel,self).build(input_shape)\n",
    "    \n",
    "    def call(self, x):\n",
    "        x = self.embedding(x)\n",
    "        x = self.conv_1(x)\n",
    "        x = self.pool_1(x)\n",
    "        x = self.conv_2(x)\n",
    "        x = self.pool_2(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.dense(x)\n",
    "        return(x)\n",
    "    \n",
    "    # 用于显示Output Shape\n",
    "    def summary(self):\n",
    "        x_input = layers.Input(shape = MAX_LEN)\n",
    "        output = self.call(x_input)\n",
    "        model = tf.keras.Model(inputs = x_input,outputs = output)\n",
    "        model.summary()\n",
    "    \n",
    "model = CnnModel()\n",
    "model.build(input_shape =(None,MAX_LEN))\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "自定义方式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#打印时间分割线\n",
    "@tf.function\n",
    "def printbar():\n",
    "    today_ts = tf.timestamp()%(24*60*60)\n",
    "    \n",
    "    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)\n",
    "    minite = tf.cast((today_ts%3600)//60,tf.int32)\n",
    "    second = tf.cast(tf.floor(today_ts%60),tf.int32)\n",
    "    \n",
    "    def timeformat(m):\n",
    "        if tf.strings.length(tf.strings.format(\"{}\",m))==1:\n",
    "            return(tf.strings.format(\"0{}\",m))\n",
    "        else:\n",
    "            return(tf.strings.format(\"{}\",m))\n",
    "    \n",
    "    timestring = tf.strings.join([timeformat(hour),timeformat(minite),\n",
    "                timeformat(second)],separator = \":\")\n",
    "    tf.print(\"==========\"*8+timestring)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optimizers.Nadam()\n",
    "loss_func = losses.BinaryCrossentropy()\n",
    "\n",
    "train_loss = metrics.Mean(name='train_loss')\n",
    "train_metric = metrics.BinaryAccuracy(name='train_accuracy')\n",
    "\n",
    "valid_loss = metrics.Mean(name='valid_loss')\n",
    "valid_metric = metrics.BinaryAccuracy(name='valid_accuracy')\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def train_step(model, features, labels):\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(features,training = True)\n",
    "        loss = loss_func(labels, predictions)\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    train_loss.update_state(loss)\n",
    "    train_metric.update_state(labels, predictions)\n",
    "    \n",
    "\n",
    "@tf.function\n",
    "def valid_step(model, features, labels):\n",
    "    predictions = model(features,training = False)\n",
    "    batch_loss = loss_func(labels, predictions)\n",
    "    valid_loss.update_state(batch_loss)\n",
    "    valid_metric.update_state(labels, predictions)\n",
    "\n",
    "\n",
    "def train_model(model,ds_train,ds_valid,epochs):\n",
    "    for epoch in tf.range(1,epochs+1):\n",
    "        \n",
    "        for features, labels in ds_train:\n",
    "            train_step(model,features,labels)\n",
    "\n",
    "        for features, labels in ds_valid:\n",
    "            valid_step(model,features,labels)\n",
    "        \n",
    "        #此处logs模板需要根据metric具体情况修改\n",
    "        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}' \n",
    "        \n",
    "        if epoch%1==0:\n",
    "            printbar()\n",
    "            tf.print(tf.strings.format(logs,\n",
    "            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))\n",
    "            tf.print(\"\")\n",
    "        \n",
    "        train_loss.reset_states()\n",
    "        valid_loss.reset_states()\n",
    "        train_metric.reset_states()\n",
    "        valid_metric.reset_states()\n",
    "\n",
    "train_model(model,ds_train,ds_test,epochs = 6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估模型\n",
    "通过自定义训练循环训练的模型没有经过编译，无法直接使用model.evaluate(ds_valid)方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(model,ds_valid):\n",
    "    for features, labels in ds_valid:\n",
    "         valid_step(model,features,labels)\n",
    "    logs = 'Valid Loss:{},Valid Accuracy:{}' \n",
    "    tf.print(tf.strings.format(logs,(valid_loss.result(),valid_metric.result())))\n",
    "    \n",
    "    valid_loss.reset_states()\n",
    "    train_metric.reset_states()\n",
    "    valid_metric.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_model(model,ds_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用模型\n",
    "可以使用以下方法:\n",
    "\n",
    "    1. model.predict(ds_test)\n",
    "    2. model(x_test)\n",
    "    3. model.call(x_test)\n",
    "    4. model.predict_on_batch(x_test)\n",
    "推荐优先使用model.predict(ds_test)方法，既可以对Dataset，也可以对Tensor使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.predict(ds_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x_test,_ in ds_test.take(1):\n",
    "    print(model(x_test))\n",
    "    #以下方法等价：\n",
    "    #print(model.call(x_test))\n",
    "    #print(model.predict_on_batch(x_test))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
